convenience.go 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. package testhelper
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "path/filepath"
  7. "reflect"
  8. "runtime"
  9. "strings"
  10. "testing"
  11. )
const (
	// logBodyFmt is the format string handed to t.Fatalf/t.Errorf by logFatal and logError. It
	// prints two strings separated by a space, wrapped in the ANSI sequences for bold red
	// ("\033[1;31m") and attribute reset ("\033[0m").
	logBodyFmt = "\033[1;31m%s %s\033[0m"

	// greenCode resets the current attributes and switches to bold green; green() emits it in
	// front of the value it formats.
	greenCode = "\033[0m\033[1;32m"

	// yellowCode resets the current attributes and switches to bold yellow; yellow() emits it in
	// front of the value it formats.
	yellowCode = "\033[0m\033[1;33m"

	// resetCode resets the current attributes and switches back to bold red, the colour
	// logBodyFmt uses for the surrounding message.
	resetCode = "\033[0m\033[1;31m"
)
  18. func prefix(depth int) string {
  19. _, file, line, _ := runtime.Caller(depth)
  20. return fmt.Sprintf("Failure in %s, line %d:", filepath.Base(file), line)
  21. }
  22. func green(str interface{}) string {
  23. return fmt.Sprintf("%s%#v%s", greenCode, str, resetCode)
  24. }
  25. func yellow(str interface{}) string {
  26. return fmt.Sprintf("%s%#v%s", yellowCode, str, resetCode)
  27. }
  28. func logFatal(t *testing.T, str string) {
  29. t.Fatalf(logBodyFmt, prefix(3), str)
  30. }
  31. func logError(t *testing.T, str string) {
  32. t.Errorf(logBodyFmt, prefix(3), str)
  33. }
  34. type diffLogger func([]string, interface{}, interface{})
  35. type visit struct {
  36. a1 uintptr
  37. a2 uintptr
  38. typ reflect.Type
  39. }
  40. // Recursively visits the structures of "expected" and "actual". The diffLogger function will be
  41. // invoked with each different value encountered, including the reference path that was followed
  42. // to get there.
  43. func deepDiffEqual(expected, actual reflect.Value, visited map[visit]bool, path []string, logDifference diffLogger) {
  44. defer func() {
  45. // Fall back to the regular reflect.DeepEquals function.
  46. if r := recover(); r != nil {
  47. var e, a interface{}
  48. if expected.IsValid() {
  49. e = expected.Interface()
  50. }
  51. if actual.IsValid() {
  52. a = actual.Interface()
  53. }
  54. if !reflect.DeepEqual(e, a) {
  55. logDifference(path, e, a)
  56. }
  57. }
  58. }()
  59. if !expected.IsValid() && actual.IsValid() {
  60. logDifference(path, nil, actual.Interface())
  61. return
  62. }
  63. if expected.IsValid() && !actual.IsValid() {
  64. logDifference(path, expected.Interface(), nil)
  65. return
  66. }
  67. if !expected.IsValid() && !actual.IsValid() {
  68. return
  69. }
  70. hard := func(k reflect.Kind) bool {
  71. switch k {
  72. case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
  73. return true
  74. }
  75. return false
  76. }
  77. if expected.CanAddr() && actual.CanAddr() && hard(expected.Kind()) {
  78. addr1 := expected.UnsafeAddr()
  79. addr2 := actual.UnsafeAddr()
  80. if addr1 > addr2 {
  81. addr1, addr2 = addr2, addr1
  82. }
  83. if addr1 == addr2 {
  84. // References are identical. We can short-circuit
  85. return
  86. }
  87. typ := expected.Type()
  88. v := visit{addr1, addr2, typ}
  89. if visited[v] {
  90. // Already visited.
  91. return
  92. }
  93. // Remember this visit for later.
  94. visited[v] = true
  95. }
  96. switch expected.Kind() {
  97. case reflect.Array:
  98. for i := 0; i < expected.Len(); i++ {
  99. hop := append(path, fmt.Sprintf("[%d]", i))
  100. deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
  101. }
  102. return
  103. case reflect.Slice:
  104. if expected.IsNil() != actual.IsNil() {
  105. logDifference(path, expected.Interface(), actual.Interface())
  106. return
  107. }
  108. if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
  109. return
  110. }
  111. for i := 0; i < expected.Len(); i++ {
  112. hop := append(path, fmt.Sprintf("[%d]", i))
  113. deepDiffEqual(expected.Index(i), actual.Index(i), visited, hop, logDifference)
  114. }
  115. return
  116. case reflect.Interface:
  117. if expected.IsNil() != actual.IsNil() {
  118. logDifference(path, expected.Interface(), actual.Interface())
  119. return
  120. }
  121. deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
  122. return
  123. case reflect.Ptr:
  124. deepDiffEqual(expected.Elem(), actual.Elem(), visited, path, logDifference)
  125. return
  126. case reflect.Struct:
  127. for i, n := 0, expected.NumField(); i < n; i++ {
  128. field := expected.Type().Field(i)
  129. hop := append(path, "."+field.Name)
  130. deepDiffEqual(expected.Field(i), actual.Field(i), visited, hop, logDifference)
  131. }
  132. return
  133. case reflect.Map:
  134. if expected.IsNil() != actual.IsNil() {
  135. logDifference(path, expected.Interface(), actual.Interface())
  136. return
  137. }
  138. if expected.Len() == actual.Len() && expected.Pointer() == actual.Pointer() {
  139. return
  140. }
  141. var keys []reflect.Value
  142. if expected.Len() >= actual.Len() {
  143. keys = expected.MapKeys()
  144. } else {
  145. keys = actual.MapKeys()
  146. }
  147. for _, k := range keys {
  148. expectedValue := expected.MapIndex(k)
  149. actualValue := actual.MapIndex(k)
  150. if !expectedValue.IsValid() {
  151. logDifference(path, nil, actual.Interface())
  152. return
  153. }
  154. if !actualValue.IsValid() {
  155. logDifference(path, expected.Interface(), nil)
  156. return
  157. }
  158. hop := append(path, fmt.Sprintf("[%v]", k))
  159. deepDiffEqual(expectedValue, actualValue, visited, hop, logDifference)
  160. }
  161. return
  162. case reflect.Func:
  163. if expected.IsNil() != actual.IsNil() {
  164. logDifference(path, expected.Interface(), actual.Interface())
  165. }
  166. return
  167. default:
  168. if expected.Interface() != actual.Interface() {
  169. logDifference(path, expected.Interface(), actual.Interface())
  170. }
  171. }
  172. }
  173. func deepDiff(expected, actual interface{}, logDifference diffLogger) {
  174. if expected == nil || actual == nil {
  175. logDifference([]string{}, expected, actual)
  176. return
  177. }
  178. expectedValue := reflect.ValueOf(expected)
  179. actualValue := reflect.ValueOf(actual)
  180. if expectedValue.Type() != actualValue.Type() {
  181. logDifference([]string{}, expected, actual)
  182. return
  183. }
  184. deepDiffEqual(expectedValue, actualValue, map[visit]bool{}, []string{}, logDifference)
  185. }
  186. // AssertEquals compares two arbitrary values and performs a comparison. If the
  187. // comparison fails, a fatal error is raised that will fail the test
  188. func AssertEquals(t *testing.T, expected, actual interface{}) {
  189. if expected != actual {
  190. logFatal(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
  191. }
  192. }
  193. // CheckEquals is similar to AssertEquals, except with a non-fatal error
  194. func CheckEquals(t *testing.T, expected, actual interface{}) {
  195. if expected != actual {
  196. logError(t, fmt.Sprintf("expected %s but got %s", green(expected), yellow(actual)))
  197. }
  198. }
  199. // AssertDeepEquals - like Equals - performs a comparison - but on more complex
  200. // structures that requires deeper inspection
  201. func AssertDeepEquals(t *testing.T, expected, actual interface{}) {
  202. pre := prefix(2)
  203. differed := false
  204. deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
  205. differed = true
  206. t.Errorf("\033[1;31m%sat %s expected %s, but got %s\033[0m",
  207. pre,
  208. strings.Join(path, ""),
  209. green(expected),
  210. yellow(actual))
  211. })
  212. if differed {
  213. logFatal(t, "The structures were different.")
  214. }
  215. }
  216. // CheckDeepEquals is similar to AssertDeepEquals, except with a non-fatal error
  217. func CheckDeepEquals(t *testing.T, expected, actual interface{}) {
  218. pre := prefix(2)
  219. deepDiff(expected, actual, func(path []string, expected, actual interface{}) {
  220. t.Errorf("\033[1;31m%s at %s expected %s, but got %s\033[0m",
  221. pre,
  222. strings.Join(path, ""),
  223. green(expected),
  224. yellow(actual))
  225. })
  226. }
  227. func isByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) bool {
  228. return bytes.Equal(expectedBytes, actualBytes)
  229. }
  230. // AssertByteArrayEquals a convenience function for checking whether two byte arrays are equal
  231. func AssertByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) {
  232. if !isByteArrayEquals(t, expectedBytes, actualBytes) {
  233. logFatal(t, "The bytes differed.")
  234. }
  235. }
  236. // CheckByteArrayEquals a convenience function for silent checking whether two byte arrays are equal
  237. func CheckByteArrayEquals(t *testing.T, expectedBytes []byte, actualBytes []byte) {
  238. if !isByteArrayEquals(t, expectedBytes, actualBytes) {
  239. logError(t, "The bytes differed.")
  240. }
  241. }
  242. // isJSONEquals is a utility function that implements JSON comparison for AssertJSONEquals and
  243. // CheckJSONEquals.
  244. func isJSONEquals(t *testing.T, expectedJSON string, actual interface{}) bool {
  245. var parsedExpected, parsedActual interface{}
  246. err := json.Unmarshal([]byte(expectedJSON), &parsedExpected)
  247. if err != nil {
  248. t.Errorf("Unable to parse expected value as JSON: %v", err)
  249. return false
  250. }
  251. jsonActual, err := json.Marshal(actual)
  252. AssertNoErr(t, err)
  253. err = json.Unmarshal(jsonActual, &parsedActual)
  254. AssertNoErr(t, err)
  255. if !reflect.DeepEqual(parsedExpected, parsedActual) {
  256. prettyExpected, err := json.MarshalIndent(parsedExpected, "", " ")
  257. if err != nil {
  258. t.Logf("Unable to pretty-print expected JSON: %v\n%s", err, expectedJSON)
  259. } else {
  260. // We can't use green() here because %#v prints prettyExpected as a byte array literal, which
  261. // is... unhelpful. Converting it to a string first leaves "\n" uninterpreted for some reason.
  262. t.Logf("Expected JSON:\n%s%s%s", greenCode, prettyExpected, resetCode)
  263. }
  264. prettyActual, err := json.MarshalIndent(actual, "", " ")
  265. if err != nil {
  266. t.Logf("Unable to pretty-print actual JSON: %v\n%#v", err, actual)
  267. } else {
  268. // We can't use yellow() for the same reason.
  269. t.Logf("Actual JSON:\n%s%s%s", yellowCode, prettyActual, resetCode)
  270. }
  271. return false
  272. }
  273. return true
  274. }
  275. // AssertJSONEquals serializes a value as JSON, parses an expected string as JSON, and ensures that
  276. // both are consistent. If they aren't, the expected and actual structures are pretty-printed and
  277. // shown for comparison.
  278. //
  279. // This is useful for comparing structures that are built as nested map[string]interface{} values,
  280. // which are a pain to construct as literals.
  281. func AssertJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
  282. if !isJSONEquals(t, expectedJSON, actual) {
  283. logFatal(t, "The generated JSON structure differed.")
  284. }
  285. }
  286. // CheckJSONEquals is similar to AssertJSONEquals, but nonfatal.
  287. func CheckJSONEquals(t *testing.T, expectedJSON string, actual interface{}) {
  288. if !isJSONEquals(t, expectedJSON, actual) {
  289. logError(t, "The generated JSON structure differed.")
  290. }
  291. }
  292. // AssertNoErr is a convenience function for checking whether an error value is
  293. // an actual error
  294. func AssertNoErr(t *testing.T, e error) {
  295. if e != nil {
  296. logFatal(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
  297. }
  298. }
  299. // CheckNoErr is similar to AssertNoErr, except with a non-fatal error
  300. func CheckNoErr(t *testing.T, e error) {
  301. if e != nil {
  302. logError(t, fmt.Sprintf("unexpected error %s", yellow(e.Error())))
  303. }
  304. }